-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Minor Touches for ScoreGradELBO
#99
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Benchmark suite | Current: 801a76b | Previous: b6083ed | Ratio |
---|---|---|---|
normal/RepGradELBO + STL/meanfield/Zygote |
15540028767 ns |
15262008164 ns |
1.02 |
normal/RepGradELBO + STL/meanfield/ForwardDiff |
3336745945 ns |
3126611341 ns |
1.07 |
normal/RepGradELBO + STL/meanfield/ReverseDiff |
3252943822 ns |
3207060677 ns |
1.01 |
normal/RepGradELBO + STL/fullrank/Zygote |
15109365150 ns |
15220603739 ns |
0.99 |
normal/RepGradELBO + STL/fullrank/ForwardDiff |
3677404700 ns |
3439213244 ns |
1.07 |
normal/RepGradELBO + STL/fullrank/ReverseDiff |
5745669202 ns |
5765830869 ns |
1.00 |
normal/RepGradELBO/meanfield/Zygote |
7223742806 ns |
7257011712 ns |
1.00 |
normal/RepGradELBO/meanfield/ForwardDiff |
2387860875 ns |
2294402627.5 ns |
1.04 |
normal/RepGradELBO/meanfield/ReverseDiff |
1474387789 ns |
1435638209 ns |
1.03 |
normal/RepGradELBO/fullrank/Zygote |
7231934077 ns |
7176706497 ns |
1.01 |
normal/RepGradELBO/fullrank/ForwardDiff |
2566869136 ns |
2493123842.5 ns |
1.03 |
normal/RepGradELBO/fullrank/ReverseDiff |
2581003358 ns |
2529829241 ns |
1.02 |
normal + bijector/RepGradELBO + STL/meanfield/Zygote |
23549114509 ns |
23064269596 ns |
1.02 |
normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff |
10271662690 ns |
9877706080 ns |
1.04 |
normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff |
5166012174 ns |
4966628075 ns |
1.04 |
normal + bijector/RepGradELBO + STL/fullrank/Zygote |
23512263396 ns |
23022827020 ns |
1.02 |
normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff |
10840780388 ns |
10583173397 ns |
1.02 |
normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff |
8241021709 ns |
8173149369 ns |
1.01 |
normal + bijector/RepGradELBO/meanfield/Zygote |
14938818567 ns |
14449178740 ns |
1.03 |
normal + bijector/RepGradELBO/meanfield/ForwardDiff |
9351470129 ns |
8845141905 ns |
1.06 |
normal + bijector/RepGradELBO/meanfield/ReverseDiff |
3197315608 ns |
3004891758 ns |
1.06 |
normal + bijector/RepGradELBO/fullrank/Zygote |
14961186110 ns |
14597576336 ns |
1.02 |
normal + bijector/RepGradELBO/fullrank/ForwardDiff |
9897958470 ns |
9482666704 ns |
1.04 |
normal + bijector/RepGradELBO/fullrank/ReverseDiff |
4617350745 ns |
4449258051 ns |
1.04 |
This comment was automatically generated by workflow using github-action-benchmark.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #99 +/- ##
==========================================
- Coverage 93.54% 92.65% -0.90%
==========================================
Files 12 12
Lines 372 354 -18
==========================================
- Hits 348 328 -20
- Misses 24 26 +2 ☔ View full report in Codecov by Sentry. |
@willtebbutt Hi, it seems only Mooncake is failing. Could you look into the error messages? If they don't quite make sense, I'll try to pack up a MWE. |
Aha! I've been waiting for an example where this happens -- I've been aware of this |
This being said, if you're able to make a MWE that I can run easily on my machine, that would be great. Just whatever the function is that you're differentiating, because I'm not exactly how your tests are structured, and therefore where to find the correct function. |
Sampling is now done out of the AD path.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@arnauqb Hi, I shifted things around a bit. Could you take a look and see if you have any comments? Also, the default value for the baseline, which you set as 1, is somewhat curious. Shouldn't this be 0 for the control variate to have no effect? Or maybe I'm misunderstanding something. |
Hi @Red-Portal thanks for tidying this up, looks good to me. Yes, you are right that baseline should be 0 to have no effect, not sure what I set it to 1 😅 |
@Red-Portal After a bit of testing, I run into an error when sampling from two priors like this: @model make_model(data)
p1 ~ Dirichlet(3, 1.0)
p2 ~ Dirichlet(2, 1.0)
data ~ CustomDistribution(vcat(p1, p2))
end and then transforming When using the score estimator I get an input missmatch To fix this, we can sample |
@arnauqb I think that's how it is done in this PR now. However, that error is quite peculiar, and it appears that there might be something wrong with Bijectors. If you have time, could you try to come up with a MWE and take it to |
The problem may be that I am using this: https://github.com/TuringLang/Turing.jl/blob/40a0d84b76e8e262e32618f83e6b895b34177d95/src/variational/advi.jl#L23 using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote
function wrap_in_vec_reshape(f, in_size)
vec_in_length = prod(in_size)
reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
out_size = Bijectors.output_size(f, in_size)
vec_out_length = prod(out_size)
reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
return reshape_outer ∘ f ∘ reshape_inner
end
function Bijectors.bijector(
model::DynamicPPL.Model, ::Val{sym2ranges} = Val(false);
varinfo = DynamicPPL.VarInfo(model)
) where {sym2ranges}
num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)])
dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)
num_ranges = sum([length(varinfo.metadata[sym].ranges)
for sym in keys(varinfo.metadata)])
ranges = Vector{UnitRange{Int}}(undef, num_ranges)
idx = 0
range_idx = 1
# ranges might be discontinuous => values are vectors of ranges rather than just ranges
sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}()
for sym in keys(varinfo.metadata)
sym_lookup[sym] = Vector{UnitRange{Int}}()
for r in varinfo.metadata[sym].ranges
ranges[range_idx] = idx .+ r
push!(sym_lookup[sym], ranges[range_idx])
range_idx += 1
end
idx += varinfo.metadata[sym].ranges[end][end]
end
bs = map(tuple(dists...)) do d
b = Bijectors.bijector(d)
if d isa Distributions.UnivariateDistribution
b
else
wrap_in_vec_reshape(b, size(d))
end
end
if sym2ranges
return (
Bijectors.Stacked(bs, ranges),
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
)
else
return Bijectors.Stacked(bs, ranges)
end
end
##
function double_normal()
return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end
@model function normal_model(data)
p1 ~ filldist(Normal(0.0, 1.0), 2)
p2 ~ Normal(0.0, 1.0)
ps = vcat(p1, p2)
for i in 1:size(data, 2)
data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
end
end
data = rand(double_normal(), 100)
model = normal_model(data)
##
d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)
bijector_transf = inverse(bijector(model))
q_transformed = transformed(q, bijector_transf)
ℓπ = DynamicPPL.LogDensityFunction(model)
elbo = AdvancedVI.ScoreGradELBO(10, entropy = StickingTheLandingEntropy()) # this doesn't
#elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy()) # This works
q, _, stats, _ = AdvancedVI.optimize(
ℓπ,
elbo,
q_transformed,
10;
adtype = AutoZygote(),
optimizer = optimizer,
) and stacktrace:
|
Hi @arnauqb , sorry for the late reply. Seems like this happens only on Zygote right? Edit: Aha! This is because Zygote attempts to differentiate through |
Everything else is passing other than Enzyme tests. |
@Red-Portal I think this line should be moved outside the gradient calculation (and passed through |
Hi! Hmm it wouldn't be using the path gradient here since |
Something similar. I have a custom Distribution for which I implement a Now you are right that samples don't carry a gradient so ultimately this won't affect the results, I think, but may be a performance boost to take it out. |
I see. Uhghh it is annoying that Zygote can't be smarter. Anyways thanks for mentioning this, I'll try to improve it. |
@yebai This one is also ready to go |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the theory here, so just commenting on the software engineering. A few questions about tests, and requesting some more documentation.
AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme" | ||
[AutoEnzyme()] | ||
else | ||
[ | ||
AutoForwardDiff(), | ||
AutoReverseDiff(), | ||
AutoZygote(), | ||
AutoMooncake(; config=Mooncake.Config()), | ||
] | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this gets used a lot, could it be in a test utils module? Orthogonal to this PR though, can be done later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I am planning to iron this out in a separate PR
@mhauru Oh I missed a minor bug. Let me fix this tomorrow and I'll ping you again when I'm done! |
@mhauru Yeah seems ready to go now! |
Thanks @Red-Portal!
Could something be added to the test suite that would have caught this bug? |
Hi @mhauru ! Nah, no need. It was a stupid typo that failed all the tests. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy on the code side. Thanks @Red-Portal!
Thank you @Red-Portal , this is very useful! |
No description provided.